import copy
import os
import random
from glob import glob
from multiprocessing import Pool
import cProfile

import numpy as np
from natsort import natsorted
import warnings

# NOTE(review): this ignores *every* warning category for the whole process, as a
# side effect of importing this module (e.g. numpy RuntimeWarnings from overflow in
# the sigmoid would be hidden too). Consider narrowing to the specific warning
# category/module that motivated it — TODO confirm which one that was.
warnings.filterwarnings("ignore")
import sys

from dn_lr.utils import calculate_evaluation_metrics, create_directory, print_and_save_eval_metrics

from dn_lr.dn.model_class_scratch_lr.logistic_regression.model import pre_compute_evidence_part_for_one_example, \
    compute_true_label_part_for_one_example, _sigmoid_function


def prepare_data_for_sampling(this_true_label_index, this_sample, this_evidence):
    """
    Build the LR input vector used to predict one label.

    The entry at ``this_true_label_index`` is removed from the current sample,
    and the remaining label values are followed by the CNN evidence.

    :param this_true_label_index: position of the label being predicted (the LR target)
    :param this_sample: current assignment of all labels
    :param this_evidence: evidence features produced by the CNN
    :return: 1-D array laid out as ``[other label values..., evidence...]``
    """
    pieces = [np.delete(this_sample, this_true_label_index), this_evidence]
    return np.concatenate(pieces)


# Perform one full Gibbs sweep: resample every label variable once, in order.
def sample_one_instance(args):
    """
    Do one systematic-scan Gibbs sweep over all label variables.

    ``args`` is a single tuple (so the function fits ``Pool.map``-style calling):
    ``(models, num_true_labels, this_evidence, var_sequence, this_sample,
    logits_for_evidence_part_for_all_models)``

    - models: per-label trained models; ``models[i]['weights']`` is read
    - num_true_labels: number of label variables
    - this_evidence: unused here; the evidence contribution arrives through
      ``logits_for_evidence_part_for_all_models`` (the visible caller precomputes it)
    - var_sequence: order in which label indices are resampled
    - this_sample: current 0/1 assignment of all labels; **modified in place**
    - logits_for_evidence_part_for_all_models: per-label evidence-only logit,
      indexed by label index

    :return: ``(this_sample_probs, this_sample)`` where ``this_sample_probs[i]`` is
        the conditional probability P(label_i = 1 | other labels, evidence) used
        when label ``i`` was resampled, and ``this_sample`` is the updated assignment.
    """
    models, num_true_labels, this_evidence, var_sequence, this_sample, logits_for_evidence_part_for_all_models = args
    this_sample_probs = np.zeros(num_true_labels)
    for this_true_label_index in var_sequence:
        this_model = models[this_true_label_index]
        # Leave-one-out: condition on the *current* values of all other labels,
        # including ones already resampled earlier in this sweep.
        this_x = np.concatenate((this_sample[:this_true_label_index], this_sample[this_true_label_index + 1:]))
        this_sample_logit = compute_true_label_part_for_one_example(this_x, this_model['weights'], num_true_labels)
        # Full logit = label-dependent part + precomputed evidence part.
        this_sample_prob = _sigmoid_function(
            np.add(this_sample_logit, logits_for_evidence_part_for_all_models[this_true_label_index]))
        this_sample_probs[this_true_label_index] = this_sample_prob
        # random.random() is uniform on [0, 1); the strict "<" gives P(1) == p exactly,
        # so p == 0 can never produce a 1 (the previous "<=" could).
        this_sample[this_true_label_index] = 1 if random.random() < this_sample_prob else 0
    return this_sample_probs, this_sample


# Run one Gibbs chain for a single example and return its mixture estimate.
def sample_for_one_example_seq(args):
    """
    Estimate per-label marginals for one example with a Gibbs chain.

    ``args`` is a single tuple so the function can be used with ``Pool.map``:
    ``(num_samples, num_true_labels, models, each_evidence, example_index)``.

    The evidence-only part of every label model's logit is computed once up front.
    Then ``num_samples`` full Gibbs sweeps are run from a random initial assignment.
    The estimate is the mean of the per-sweep conditional probabilities (mixture
    estimator). The random initial 0/1 assignment is only the chain's starting
    point and is not averaged in.

    :return: ``(example_index, estimate)``; ``estimate`` has length ``num_true_labels``
    :raises ValueError: if ``num_samples`` is less than 1 (no estimate can be formed)
    """
    num_samples, num_true_labels, models, each_evidence, example_index = args
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if example_index % 100 == 0:
        print(f"We are at example number {example_index}")

    # Fresh, entropy-seeded generator per call: with the 'fork' start method every
    # Pool worker inherits the same global np.random state, which would make the
    # workers start their chains from identical initial samples.
    this_sample = np.random.default_rng().binomial(n=1, p=0.5, size=num_true_labels)
    var_sequence = np.arange(num_true_labels)

    # The evidence part of each label's logit does not change across sweeps,
    # so compute it once per example.
    logits_for_evidence_part_for_all_models = [
        pre_compute_evidence_part_for_one_example(each_evidence, models[this_true_label_index]['weights'],
                                                  models[this_true_label_index]['bias'], num_true_labels)
        for this_true_label_index in var_sequence
    ]

    # Running sum instead of np.vstack in the loop (which re-copied the full history each sweep).
    prob_sum = np.zeros(num_true_labels)
    for _ in range(num_samples):
        this_sample_prob_1, this_sample = sample_one_instance((models, num_true_labels, each_evidence, var_sequence,
                                                               this_sample, logits_for_evidence_part_for_all_models))
        prob_sum += this_sample_prob_1
    return example_index, prob_sum / num_samples


def gibbs_sampling(valid_actual_output, valid_predictions, model_dicts, num_samples, num_pools, model_save_path,
                   logger):
    """
    Run Gibbs-sampling inference for every example in parallel, then evaluate and save.

    :param valid_actual_output: ground-truth labels; its ``.shape`` gives
        ``(num_examples, num_true_labels)``
    :param valid_predictions: per-example CNN evidence, iterated row by row; must
        have exactly ``num_examples`` rows
    :param model_dicts: per-label trained LR models, forwarded to every worker task
    :param num_samples: number of Gibbs sweeps per example
    :param num_pools: number of worker processes
    :param model_save_path: root directory; results are written under ``<model_save_path>/dn/``
    :param logger: forwarded to ``print_and_save_eval_metrics``
    :raises ValueError: if the row counts of the two inputs disagree
    """
    num_examples, num_true_labels = valid_actual_output.shape
    if len(valid_predictions) != num_examples:
        raise ValueError(f"valid_predictions has {len(valid_predictions)} rows, "
                         f"expected {num_examples} to match valid_actual_output")

    inputs = [(num_samples, num_true_labels, model_dicts, each_evidence, example_index)
              for example_index, each_evidence in enumerate(valid_predictions)]
    # The context manager tears the workers down even if a task raises. map() blocks
    # until every result has arrived, so the terminate() on exit loses nothing.
    with Pool(processes=num_pools) as pool:
        output = pool.map(sample_for_one_example_seq, inputs)
    # Pool.map returns results in input order, so the carried index needs no reordering.
    outputs = np.array([each_prob for _, each_prob in output])

    results_dir = f'{model_save_path}/dn/'
    create_directory(results_dir)
    thresholds = (0.1, 0.2, 0.3, 0.5)
    for threshold in thresholds:
        # Hand the metric helper its own copy, so any in-place binarisation it might
        # do cannot leak into later thresholds or into the saved probabilities.
        eval_metrics = calculate_evaluation_metrics(valid_actual_output, outputs.copy(), threshold)
        output_filename = f'{results_dir}threshold_{threshold}_results.csv'
        print(eval_metrics)
        print_and_save_eval_metrics(eval_metrics, output_filename, logger)

    # The probabilities do not depend on the threshold. They are saved once, in the
    # directory named after the last threshold. Previously this relied on the loop
    # variable leaking; the on-disk layout is kept unchanged.
    output_data_location = f'{results_dir}threshold_{thresholds[-1]}'
    create_directory(output_data_location)
    np.savetxt(f'{output_data_location}/test.output', outputs, delimiter=",")
